{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3362a434",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration lansinuote--ChnSentiCorp-4d058ef86e3db8d5\n",
      "Reusing dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--ChnSentiCorp-4d058ef86e3db8d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dba5d70e13d54e3b8b7e0f0e05b23065",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "(9192,\n",
       " '选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般')"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from datasets import load_dataset\n",
    "\n",
    "\n",
    "#定义数据集\n",
    "class Dataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, split):\n",
    "        dataset = load_dataset(path='lansinuote/ChnSentiCorp', split=split)\n",
    "\n",
    "        def f(data):\n",
    "            return len(data['text']) > 30\n",
    "\n",
    "        self.dataset = dataset.filter(f)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        text = self.dataset[i]['text']\n",
    "\n",
    "        return text\n",
    "\n",
    "\n",
    "dataset = Dataset('train')\n",
    "\n",
    "len(dataset), dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e70a58c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import BertTokenizer\n",
    "\n",
    "#加载字典和分词工具\n",
    "token = BertTokenizer.from_pretrained('bert-base-chinese')\n",
    "\n",
    "token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e59695a4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "574\n",
      "[CLS] 屏 幕 的 长 宽 比 例 太 奇 怪 ， 看 了 很 [MASK] 爽 ！ 镜 面 感 太 强 ， 还 是 喜 欢 亚 [SEP]\n",
      "不\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(torch.Size([16, 30]),\n",
       " torch.Size([16, 30]),\n",
       " torch.Size([16, 30]),\n",
       " torch.Size([16]))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def collate_fn(data):\n",
    "    #编码\n",
    "    data = token.batch_encode_plus(batch_text_or_text_pairs=data,\n",
    "                                   truncation=True,\n",
    "                                   padding='max_length',\n",
    "                                   max_length=30,\n",
    "                                   return_tensors='pt',\n",
    "                                   return_length=True)\n",
    "\n",
    "    #input_ids:编码之后的数字\n",
    "    #attention_mask:是补零的位置是0,其他位置是1\n",
    "    input_ids = data['input_ids']\n",
    "    attention_mask = data['attention_mask']\n",
    "    token_type_ids = data['token_type_ids']\n",
    "\n",
    "    #把第15个词固定替换为mask\n",
    "    labels = input_ids[:, 15].reshape(-1).clone()\n",
    "    input_ids[:, 15] = token.get_vocab()[token.mask_token]\n",
    "\n",
    "    #print(data['length'], data['length'].max())\n",
    "\n",
    "    return input_ids, attention_mask, token_type_ids, labels\n",
    "\n",
    "\n",
    "#数据加载器\n",
    "loader = torch.utils.data.DataLoader(dataset=dataset,\n",
    "                                     batch_size=16,\n",
    "                                     collate_fn=collate_fn,\n",
    "                                     shuffle=True,\n",
    "                                     drop_last=True)\n",
    "\n",
    "for i, (input_ids, attention_mask, token_type_ids,\n",
    "        labels) in enumerate(loader):\n",
    "    break\n",
    "\n",
    "print(len(loader))\n",
    "print(token.decode(input_ids[0]))\n",
    "print(token.decode(labels[0]))\n",
    "input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f620d0e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 30, 768])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import BertModel\n",
    "\n",
    "#加载预训练模型\n",
    "pretrained = BertModel.from_pretrained('bert-base-chinese')\n",
    "\n",
    "#不训练,不需要计算梯度\n",
    "for param in pretrained.parameters():\n",
    "    param.requires_grad_(False)\n",
    "\n",
    "#模型试算\n",
    "out = pretrained(input_ids=input_ids,\n",
    "           attention_mask=attention_mask,\n",
    "           token_type_ids=token_type_ids)\n",
    "\n",
    "out.last_hidden_state.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b273bf7c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 21128])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#定义下游任务模型\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.decoder = torch.nn.Linear(768, token.vocab_size, bias=False)\n",
    "        self.bias = torch.nn.Parameter(torch.zeros(token.vocab_size))\n",
    "        self.decoder.bias = self.bias\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, token_type_ids):\n",
    "        with torch.no_grad():\n",
    "            out = pretrained(input_ids=input_ids,\n",
    "                             attention_mask=attention_mask,\n",
    "                             token_type_ids=token_type_ids)\n",
    "\n",
    "        out = self.decoder(out.last_hidden_state[:, 15])\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "model = Model()\n",
    "\n",
    "model(input_ids=input_ids,\n",
    "      attention_mask=attention_mask,\n",
    "      token_type_ids=token_type_ids).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1bd44a7c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt36/lib/python3.6/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  FutureWarning,\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0 9.913239479064941 0.0\n",
      "0 50 7.882239818572998 0.1875\n",
      "0 100 6.802951812744141 0.125\n",
      "0 150 4.709199905395508 0.25\n",
      "0 200 5.053264141082764 0.375\n",
      "0 250 5.4513750076293945 0.375\n",
      "0 300 4.834373474121094 0.4375\n",
      "0 350 2.740715503692627 0.6875\n",
      "0 400 4.249405384063721 0.375\n",
      "0 450 2.4256627559661865 0.6875\n",
      "0 500 3.162682294845581 0.375\n",
      "0 550 2.479234457015991 0.5625\n",
      "1 0 3.3130388259887695 0.4375\n",
      "1 50 2.428676128387451 0.625\n",
      "1 100 1.8036706447601318 0.8125\n",
      "1 150 2.2352967262268066 0.625\n",
      "1 200 1.6442852020263672 0.8125\n",
      "1 250 1.8351222276687622 0.75\n",
      "1 300 1.6686547994613647 0.8125\n",
      "1 350 2.3184170722961426 0.5625\n",
      "1 400 2.627448797225952 0.5625\n",
      "1 450 2.4967753887176514 0.5625\n",
      "1 500 2.7677855491638184 0.625\n",
      "1 550 2.1934893131256104 0.6875\n",
      "2 0 1.5646746158599854 0.75\n",
      "2 50 1.2516077756881714 0.75\n",
      "2 100 0.6442171931266785 0.9375\n",
      "2 150 0.8584479093551636 0.875\n",
      "2 200 2.042940855026245 0.75\n",
      "2 250 0.7246753573417664 0.8125\n",
      "2 300 1.1209107637405396 0.75\n",
      "2 350 1.4746692180633545 0.75\n",
      "2 400 0.8257594108581543 0.875\n",
      "2 450 1.384639859199524 0.6875\n",
      "2 500 1.2145320177078247 0.8125\n",
      "2 550 1.2005428075790405 0.75\n",
      "3 0 0.8594522476196289 0.75\n",
      "3 50 0.474687784910202 0.9375\n",
      "3 100 0.5052637457847595 0.875\n",
      "3 150 0.9241865277290344 0.75\n",
      "3 200 0.6268807649612427 0.9375\n",
      "3 250 0.5165101885795593 0.875\n",
      "3 300 0.5541099905967712 0.9375\n",
      "3 350 0.6904842853546143 0.9375\n",
      "3 400 0.2985838055610657 0.9375\n",
      "3 450 1.1752042770385742 0.75\n",
      "3 500 0.7742744088172913 0.9375\n",
      "3 550 0.7787161469459534 0.75\n",
      "4 0 0.23652660846710205 1.0\n",
      "4 50 0.28316396474838257 1.0\n",
      "4 100 0.5179316997528076 0.875\n",
      "4 150 0.3183712959289551 1.0\n",
      "4 200 0.5086650252342224 0.875\n",
      "4 250 0.28567108511924744 0.9375\n",
      "4 300 0.4397639334201813 0.875\n",
      "4 350 0.23462174832820892 0.9375\n",
      "4 400 0.5601446628570557 0.875\n",
      "4 450 0.1961936503648758 1.0\n",
      "4 500 0.32445845007896423 1.0\n",
      "4 550 0.2166948765516281 0.9375\n"
     ]
    }
   ],
   "source": [
    "from transformers import AdamW\n",
    "\n",
    "#训练\n",
    "optimizer = AdamW(model.parameters(), lr=5e-4)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "model.train()\n",
    "for epoch in range(5):\n",
    "    for i, (input_ids, attention_mask, token_type_ids,\n",
    "            labels) in enumerate(loader):\n",
    "        out = model(input_ids=input_ids,\n",
    "                    attention_mask=attention_mask,\n",
    "                    token_type_ids=token_type_ids)\n",
    "\n",
    "        loss = criterion(out, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        if i % 50 == 0:\n",
    "            out = out.argmax(dim=1)\n",
    "            accuracy = (out == labels).sum().item() / len(labels)\n",
    "\n",
    "            print(epoch, i, loss.item(), accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "275dd1b6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration lansinuote--ChnSentiCorp-4d058ef86e3db8d5\n",
      "Reusing dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--ChnSentiCorp-4d058ef86e3db8d5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "edd1803acb2c40229c7eea1605f1263e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "[CLS] 床 发 出 吱 嘎 吱 嘎 的 声 音 ， 房 间 隔 [MASK] 太 差 ， 赠 送 的 早 餐 非 常 好 吃 。 [SEP]\n",
      "音 音\n",
      "1\n",
      "[CLS] 非 常 不 好 ， 我 们 渡 过 了 一 个 让 人 [MASK] 以 忍 受 的 纪 念 日. com / thread - 136 [SEP]\n",
      "难 难\n",
      "2\n",
      "[CLS] 定 的 商 务 大 床 房 ， 房 间 偏 小 了 ， [MASK] 过 经 济 性 酒 店 也 就 这 样 ； 环 境 [SEP]\n",
      "不 不\n",
      "3\n",
      "[CLS] 确 实 是 山 上 最 好 的 酒 店 ， 环 境 和 [MASK] 施 都 很 不 错 。 我 们 这 次 住 的 是 [SEP]\n",
      "设 设\n",
      "4\n",
      "[CLS] 这 本 书 紧 接 《 春 秋 大 义 》 ， 作 者 [MASK] 以 贯 之 地 以 浅 显 的 语 言 ， 告 诉 [SEP]\n",
      "一 一\n",
      "5\n",
      "[CLS] 非 常 不 满 这 酒 店 ， 配 不 上 5 星 。 [MASK] 一, 客 房 服 务 员 没 有 水 平, 房 [SEP]\n",
      "第 第\n",
      "6\n",
      "[CLS] 合 庆 的 商 务 单 间 可 以 堪 称 豪 华, [MASK] 施 特 别 先 进 ， 特 别 是 少 有 的 先 [SEP]\n",
      "设 设\n",
      "7\n",
      "[CLS] 这 是 我 住 过 的 最 差 的 酒 店 ， 房 间 [MASK] 味 难 闻 ， 刚 打 了 灭 蚊 药 水 ， 换 [SEP]\n",
      "气 气\n",
      "8\n",
      "[CLS] 总 体 很 满 意 ， 但 有 一 点 需 改 进 ， [MASK] 在 9 楼 入 住 ， 走 时 到 1 楼 前 台 [SEP]\n",
      "我 我\n",
      "9\n",
      "[CLS] 这 本 书 有 别 于 以 往 看 过 的 早 教 书 [MASK] ， 结 合 了 说 明 文 的 写 实 ， 散 文 [SEP]\n",
      "籍 籍\n",
      "10\n",
      "[CLS] [UNK] 用 起 来 不 习 惯 ， 速 度 慢 ， 分 区 [MASK] 烦 ， 带 了 很 多 垃 圾 软 件 ， 卸 载 [SEP]\n",
      "麻 麻\n",
      "11\n",
      "[CLS] 这 一 套 书 我 基 本 买 齐 了 ， 也 看 了 [MASK] 多 本 了 。 是 利 用 闲 暇 时 间 巩 固 [SEP]\n",
      "好 好\n",
      "12\n",
      "[CLS] 渡 假 村 周 围 景 色 不 错, 但 较 落 乡 [MASK] 硬 件 太 差, 服 务 水 准 有 待 提 高 [SEP]\n",
      ". .\n",
      "13\n",
      "[CLS] 相 对 来 说 ， 比 起 东 莞 ， 深 圳 的 酒 [MASK] ， 该 酒 店 服 务 质 量 还 是 略 有 小 [SEP]\n",
      "店 店\n",
      "14\n",
      "[CLS] 目 前 是 新 昌 最 好 的 酒 店 ， 环 境 和 [MASK] 务 质 量 都 非 常 好 ， 国 内 长 度 免 [SEP]\n",
      "服 服\n",
      "0.6875\n"
     ]
    }
   ],
   "source": [
    "#测试\n",
    "def test():\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),\n",
    "                                              batch_size=32,\n",
    "                                              collate_fn=collate_fn,\n",
    "                                              shuffle=True,\n",
    "                                              drop_last=True)\n",
    "\n",
    "    for i, (input_ids, attention_mask, token_type_ids,\n",
    "            labels) in enumerate(loader_test):\n",
    "\n",
    "        if i == 15:\n",
    "            break\n",
    "\n",
    "        print(i)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            out = model(input_ids=input_ids,\n",
    "                        attention_mask=attention_mask,\n",
    "                        token_type_ids=token_type_ids)\n",
    "\n",
    "        out = out.argmax(dim=1)\n",
    "        correct += (out == labels).sum().item()\n",
    "        total += len(labels)\n",
    "\n",
    "        print(token.decode(input_ids[0]))\n",
    "        print(token.decode(labels[0]), token.decode(labels[0]))\n",
    "\n",
    "    print(correct / total)\n",
    "\n",
    "\n",
    "test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
